
import json
import datasets
from fire import Fire
from functools import partial
from typing import List
from loguru import logger
import os
import time
import sys
sys.path.append("..")
from utils import (
    generate_together,
    generate_with_references,
)
from datasets import load_dataset

def process_fn(
    item,
    model,
    reference_models=None,
    temperature=0.7,
    max_tokens=2048,
    rounds=1,
    n=1,
    aggPrompt="Default",
):
    """Produce the final generation(s) for one dataset row.

    The row's ``'prompt'`` is sent as a single user message. If the row has
    no ``'references'`` (missing or empty) and ``reference_models`` is
    non-empty, references are generated on the fly: in each of ``rounds``
    rounds every reference model is queried once and receives the previous
    round's answers as its own references (the first round receives none).
    The answers of the last round are then handed to ``model``.

    Args:
        item: Mapping with a ``'prompt'`` key and optionally ``'references'``.
        model: Name of the model producing the final output.
        reference_models: Model names used to generate references; ``None``
            (the default) or an empty list disables on-the-fly generation.
        temperature: Forwarded to every generation call.
        max_tokens: Forwarded to every generation call.
        rounds: Number of reference-generation rounds.
        n: Forwarded only to the final generation call.
        aggPrompt: Forwarded only to the final generation call.

    Returns:
        A dict with ``'input'`` and ``'output'`` (the ``"inputs"`` and
        ``"outputs"`` entries of the final call's result) and ``'generator'``
        (``model + '-together'``).
    """
    # A list literal as a default would be shared between calls; use None
    # as the "not provided" marker instead.
    if reference_models is None:
        reference_models = []

    instructions = item['prompt']
    messages = [{'role': 'user', 'content': instructions}]

    references = item.get('references', [])

    if len(references) == 0 and len(reference_models) > 0:
        round_inputs = []
        for _ in range(rounds):
            round_outputs = []
            for reference_model in reference_models:
                reference = generate_with_references(
                    model=reference_model,
                    messages=messages,
                    references=round_inputs,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    generate_fn=generate_together,
                )
                # A None result is skipped, so a round may yield fewer
                # references than there are reference models.
                if reference is not None:
                    round_outputs.append(reference["outputs"][0])
            # This round's answers feed the next round; after the loop the
            # last round's answers are what the final model aggregates.
            round_inputs = round_outputs
            references = round_outputs

    # NOTE(review): unlike the reference calls above, this call does not pass
    # generate_fn and its result is not checked for None — confirm that
    # generate_with_references has a suitable default and always returns a
    # mapping here.
    result = generate_with_references(
        model=model,
        messages=messages,
        references=references,
        temperature=temperature,
        n=n,
        max_tokens=max_tokens,
        aggPrompt=aggPrompt,
    )

    return {
        'input': result["inputs"],
        'output': result["outputs"],
        'generator': model + '-together',
    }


def _split_csv_arg(value):
    """Normalize a comma-separated CLI value into a list of strings."""
    if value is None:
        return []
    # python-fire may already have parsed "a,b" into a tuple/list before it
    # reaches us; accept that form instead of failing on `.split`.
    if isinstance(value, (list, tuple)):
        return [str(part) for part in value]
    return value.split(',')


def main(
    model: str,
    output_path: str,
    additional_info: str = "",
    reference_paths: str = None,
    reference_models: str = None,
    num_reference_path: int = None,
    aggPrompt: str = "Default",
    temperature: float = 0.7,
    max_tokens: int = 2048,
    rounds: int = 1,
    num_proc: int = 16,
    n: int = 1,
):
    """Run ``model`` over the UltraFeedback ``train_prefs`` split and save JSON.

    References for each prompt come either from pre-computed JSON files
    (``reference_paths``) or, if none are given, are generated on the fly by
    ``reference_models`` inside ``process_fn``.

    Args:
        model: Name of the model that produces the final outputs.
        output_path: Base directory; results go to
            ``<output_path>/<last path component of model>/``.
        additional_info: Free-form suffix appended to the output file name.
        reference_paths: Comma-separated list of JSON files, or a glob
            pattern containing ``*`` (matches are sorted). Each file must
            hold a list of objects with an ``'output'`` key; the i-th entry
            of every file becomes a reference for the i-th dataset row.
        reference_models: Comma-separated model names used to generate
            references when no reference files are given.
        num_reference_path: Use only the first this-many reference files;
            defaults to all of them. When set (or when reference files are
            used) it is also embedded in the output file name for ``n == 1``.
        aggPrompt: Forwarded to ``process_fn``.
        temperature: Forwarded to ``process_fn``.
        max_tokens: Forwarded to ``process_fn``.
        rounds: Forwarded to ``process_fn``; also part of the file name.
        num_proc: Number of worker processes for ``Dataset.map``.
        n: Number of outputs per prompt; one JSON file is written per index.
    """
    if isinstance(reference_paths, str) and "*" in reference_paths:
        import glob
        reference_paths = sorted(glob.glob(reference_paths))
    else:
        reference_paths = _split_csv_arg(reference_paths)

    reference_models = _split_csv_arg(reference_models)

    eval_set = load_dataset("HuggingFaceH4/ultrafeedback_binarized")["train_prefs"]

    if len(reference_paths):
        num_reference_path = len(reference_paths) if num_reference_path is None else num_reference_path
        reference_paths = reference_paths[:num_reference_path]
        logger.info(f"`reference_paths` provided: {reference_paths}")

        # references[i] collects the i-th 'output' of every reference file.
        references = []
        for reference_path in reference_paths:
            with open(reference_path, encoding="utf-8") as f:
                reference_responses = json.load(f)
            logger.info(f"Reading reference outputs: {reference_path} ({len(reference_responses)})")
            for row_idx, reference_response in enumerate(reference_responses):
                if len(references) <= row_idx:
                    references.append([reference_response['output']])
                else:
                    references[row_idx].append(reference_response['output'])

        eval_set = eval_set.add_column("references", references)

    elif len(reference_models):

        logger.info(f"`reference_models` provided: {reference_models}, {len(reference_models)} of them. Will generate reference responses on-the-fly.")

    logger.info("Start.")
    logger.info(eval_set)

    eval_set = eval_set.map(
        partial(
            process_fn,
            model=model,
            reference_models=reference_models,
            temperature=temperature,
            max_tokens=max_tokens,
            rounds=rounds,
            n=n,
            aggPrompt=aggPrompt,
        ),
        batched=False, num_proc=num_proc,
    )
    model_name = model.split('/')[-1]
    output_dir = f'{output_path}/{model_name}/'
    os.makedirs(output_dir, exist_ok=True)

    # Drop bulky columns that should not end up in the saved JSON. Only
    # columns that actually exist are removed, so a missing one (e.g.
    # "references" when none were loaded from files) neither raises nor
    # prevents the others from being dropped.
    unwanted_columns = [
        "chosen", "rejected", "messages", "score_chosen", "score_rejected",
        "references",
    ]
    present_columns = [
        column for column in unwanted_columns if column in eval_set.column_names
    ]
    if present_columns:
        eval_set = eval_set.remove_columns(present_columns)

    eval_set_list = list(eval_set)
    for i in range(n):
        # Use a dedicated local for the file name rather than overwriting the
        # `output_path` argument.
        if n == 1:
            num_reference_path_str = f"-num_reference_path{num_reference_path}" if num_reference_path is not None else ""
            output_file = f'{output_dir}/{model_name}-round_{rounds}-temp{temperature}{num_reference_path_str}_{additional_info}.json'
        else:
            output_file = f'{output_dir}/{model_name}-round_{rounds}-temp{temperature}-{i}_{n}_{additional_info}.json'

        logger.info(f"Saving outputs to {output_file}.")

        # Each row holds lists of inputs/outputs; file i gets the i-th entry.
        new_eval_set = []
        for item in eval_set_list:
            new_item = item.copy()
            new_item['output'] = item['output'][i]
            new_item['input'] = item['input'][i]
            new_eval_set.append(new_item)

        with open(output_file, 'w') as f:
            json.dump(new_eval_set, f, indent=2)


if __name__ == "__main__":
    # Script entry point: hand `main` to python-fire so it can be driven
    # from the command line.
    Fire(main)